# What's this?
Example code for the paper "Understanding MLP-Mixer as a Wide and Sparse MLP".



# Packages
- bmlp/
    - sc_balance.py  ### code for solving the equations for S given C and Ω
    - models/
        - conjugation.py  ### code for the RP (random permutation) tools
        - masked_mlp.py   ### code for MaskedMLP
        - mlpmixer.py     ### code for MLP-Mixer and S-Mixer

- datasets/ ### STL-10 dataset adapted from torchvision's STL10, with a modified download step
    - stl10.py

- mpi/ ### parallel training on ImageNet (requires Open MPI)
    - timm_main.py
    - timm_main_nn.py

- Pipfile  
- README.md

- timm/ ### a partially modified copy of the timm library

- train_fixed_patch_supp.py ### example code for training MLP-Mixer, S-Mixer, and MaskedMLP on CIFAR-10, CIFAR-100, and STL-10



# Install
```bash
pip install --user pipenv 
pipenv install 
```

# Examples
MLP-Mixer with random permutation:
```bash
python train_fixed_patch.py  --L 8 --dim 64 --num_connection 262144 --fix 2  --net MLPMixer  --permute 2  
```
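The `--permute` option concerns permutations of the feature map between mixing steps. As a minimal NumPy sketch (not the code in `bmlp/models/conjugation.py`), assume the features form an S×C matrix X flattened row-major into a vector x. Token mixing `W @ X` can then be written as a block-diagonal map conjugated by a permutation P, namely P⁻¹(I_C ⊗ W)P x, where P is the transpose permutation. One way to read "random permutation" (our assumption) is to draw P at random instead. All helper names below are hypothetical.

```python
import numpy as np

def block_diag_apply(x, W, num_blocks):
    # Apply (I_num_blocks ⊗ W) to x: split x into num_blocks consecutive
    # chunks and multiply each chunk by W.
    return (x.reshape(num_blocks, W.shape[1]) @ W.T).reshape(-1)

def conjugated_apply(x, W, num_blocks, perm):
    # Compute P^{-1} (I ⊗ W) P x, where (P x)[i] = x[perm[i]].
    inverse = np.argsort(perm)
    return block_diag_apply(x[perm], W, num_blocks)[inverse]

def transpose_perm(S, C):
    # Permutation mapping row-major vec(X) to row-major vec(X^T), X of shape (S, C).
    return np.arange(S * C).reshape(S, C).T.reshape(-1)

rng = np.random.default_rng(0)
S, C = 4, 3
X = rng.standard_normal((S, C))
W_tok = rng.standard_normal((S, S))

# With the transpose permutation, the conjugated block-diagonal map is token mixing.
y = conjugated_apply(X.reshape(-1), W_tok, C, transpose_perm(S, C))
print(np.allclose(y, (W_tok @ X).reshape(-1)))  # True

# Hypothetical random-permutation variant: same block-diagonal map, random P.
y_rp = conjugated_apply(X.reshape(-1), W_tok, C, rng.permutation(S * C))
```

The point of the conjugated form is that the only difference between the two variants is the choice of P; the block-diagonal weight itself is unchanged.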
MLP-Mixer without permutation:
```bash
python train_fixed_patch.py  --L 8 --dim 64 --num_connection 262144 --fix 2   --net MLPMixer --expansion_factor 4  --permute 0
```
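As a rough picture of what `--expansion_factor 4` controls, here is a minimal NumPy sketch of a standard MLP-Mixer block, assuming the factor multiplies the hidden width of both the token-mixing and channel-mixing MLPs as in the original MLP-Mixer. LayerNorm is omitted and ReLU stands in for GELU; this is not the repo's `mlpmixer.py`, and the function names are hypothetical.

```python
import numpy as np

def mlp(x, W1, W2):
    # Two-layer MLP along the last axis: d -> hidden -> d (ReLU for simplicity).
    return np.maximum(x @ W1.T, 0.0) @ W2.T

def init_mixer_block(S, C, expansion_factor, rng):
    # Hidden widths are expansion_factor times the dimension being mixed.
    hs, hc = expansion_factor * S, expansion_factor * C
    return {
        "tok1": rng.standard_normal((hs, S)) / np.sqrt(S),
        "tok2": rng.standard_normal((S, hs)) / np.sqrt(hs),
        "ch1": rng.standard_normal((hc, C)) / np.sqrt(C),
        "ch2": rng.standard_normal((C, hc)) / np.sqrt(hc),
    }

def mixer_block(X, p):
    # X has shape (S, C): token mixing acts over S, channel mixing over C.
    X = X + mlp(X.T, p["tok1"], p["tok2"]).T  # token-mixing MLP with skip
    X = X + mlp(X, p["ch1"], p["ch2"])        # channel-mixing MLP with skip
    return X

rng = np.random.default_rng(0)
S, C = 16, 64  # hypothetical sizes
params = init_mixer_block(S, C, expansion_factor=4, rng=rng)
out = mixer_block(rng.standard_normal((S, C)), params)
print(out.shape)  # (16, 64)
```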

S-Mixer with random permutation:
```bash
python train_fixed_patch.py  --L 8 --dim 64 --num_connection 262144 --fix 2  --net MLPMixer  --expansion_factor -1 --permute 2  
```
S-Mixer without permutation:
```bash
python train_fixed_patch.py  --L 8 --dim 64 --num_connection 262144 --fix 2   --net MLPMixer --expansion_factor -1 --permute 0
```
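The paper's title refers to viewing MLP-Mixer as a wide and sparse MLP. The NumPy check below illustrates the underlying linear-algebra identity under an assumption we have not verified against `mlpmixer.py`: that an S-Mixer-like block has token- and channel-mixing steps that are each a single linear map, W (S×S) and V (C×C). On the row-major flattened features, each step alone is a sparse (SC)×(SC) matrix, and with the activation removed the two compose into the Kronecker product W ⊗ V. The weights are random placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
S, C = 4, 3
X = rng.standard_normal((S, C))
W = rng.standard_normal((S, S))  # token-mixing weight (hypothetical)
V = rng.standard_normal((C, C))  # channel-mixing weight (hypothetical)

x = X.reshape(-1)  # row-major flattening, length S*C

# Each mixing step alone is a sparse (S*C) x (S*C) matrix acting on x.
token_mix = np.kron(W, np.eye(C))    # X -> W X,   S*S*C nonzeros
channel_mix = np.kron(np.eye(S), V)  # X -> X V^T, S*C*C nonzeros

assert np.allclose(token_mix @ x, (W @ X).reshape(-1))
assert np.allclose(channel_mix @ x, (X @ V.T).reshape(-1))

# Without the activation, the two steps compose into W ⊗ V.
assert np.allclose(channel_mix @ token_mix, np.kron(W, V))
print(np.count_nonzero(token_mix), np.count_nonzero(channel_mix))
```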

MaskedMLP:
```bash
python train_fixed_patch.py  --lr 1e-2 --L 8 --dim 64 --num_connection 262144 --fix 2  --net MaskedMLP --freezing_rate 0.5  
```
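A minimal NumPy sketch of a masked linear layer, assuming (our interpretation, not checked against `bmlp/models/masked_mlp.py`) that `--freezing_rate` is the fraction of weights held at zero by a random binary mask drawn once at initialization. The class and method names are hypothetical.

```python
import numpy as np

class MaskedLinear:
    """Linear layer y = x @ (W * M).T with a fixed binary mask M (hypothetical sketch)."""

    def __init__(self, in_dim, out_dim, freezing_rate, rng):
        n = in_dim * out_dim
        num_frozen = int(round(freezing_rate * n))
        flat = np.ones(n)
        flat[rng.permutation(n)[:num_frozen]] = 0.0  # zero out a random subset
        self.mask = flat.reshape(out_dim, in_dim)
        self.W = rng.standard_normal((out_dim, in_dim)) / np.sqrt(in_dim) * self.mask

    def forward(self, x):
        return x @ (self.W * self.mask).T

    def sgd_step(self, grad_W, lr):
        # Masking the gradient keeps frozen entries at exactly zero.
        self.W -= lr * grad_W * self.mask

rng = np.random.default_rng(0)
layer = MaskedLinear(in_dim=64, out_dim=64, freezing_rate=0.5, rng=rng)
print(1.0 - layer.mask.mean())  # fraction of frozen weights: 0.5
```

Masking the gradient as well as the forward weights is one simple way to keep the sparsity pattern fixed during training.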



## ImageNet-1k
```bash
cd mpi
```
### Requirements
- gcc/12.2.0
- python/3.10.10
- cuda/11.3.1
- cudnn/8.3.3
- nccl/2.9.9-1
- openmpi/4.1.3

Python packages:
- `pip install torch torchvision`
- `pip install ipython`
- `pip install wandb`
- `pip install einops`
- `pip install tqdm`
- `pip install tensorboard`
- `pip install mpi4py`
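Before launching a multi-node job, it can help to confirm that every node sees the same Python environment. A small stdlib-only sketch using `importlib.metadata`; the helper name `check_environment` is hypothetical, and the Python 3.10 minimum is our assumption based on the `python/3.10.10` entry above.

```python
import sys
from importlib import metadata

# Distribution names from the pip commands above.
REQUIRED = ["torch", "torchvision", "ipython", "wandb", "einops", "tqdm", "tensorboard", "mpi4py"]

def check_environment(packages=REQUIRED, min_python=(3, 10)):
    """Return installed versions and a list of missing packages (hypothetical helper)."""
    report = {"python_ok": sys.version_info[:2] >= min_python, "versions": {}, "missing": []}
    for name in packages:
        try:
            report["versions"][name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report["missing"].append(name)
    return report

if __name__ == "__main__":
    result = check_environment()
    print("Python >= 3.10:", result["python_ok"])
    for name, version in result["versions"].items():
        print(f"{name}=={version}")
    if result["missing"]:
        print("missing:", " ".join(result["missing"]))
```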

### Example
Training on ImageNet-1k with our code requires 32 nodes with 4 GPUs each (128 GPUs in total).
Data-parallel training on 128 GPUs with a global mini-batch size of 4096 (32 per GPU):
```bash
mpirun \
  -npernode 4 \
  -np 128 \
  -x PATH \
  -x LIBRARY_PATH \
  -x LD_LIBRARY_PATH \
  -x CUDA_HOME \
  -x CUDA_CACHE_DISABLE=1 \
  python timm_main.py --model 'MLPMixer' --data-dir YOUR-IMAGENET-DIR \
  --epochs 300 \
  --batch_size_training 32 --batch_size_test 32 \
  --flag_nccl 1 --flag_wandb 1 --loop 4 \
  --no-prefetcher \
  --lr 1e-3 \
  --sched-on-updates \
  --permute 0 -L 8
```
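The numbers in this command fit together as follows, assuming `--batch_size_training` is the per-process batch and each MPI process drives one GPU (32 nodes × 4 GPUs × 32 = 4096). The sketch also counts optimizer updates per epoch, which matters with timm's `--sched-on-updates` (the LR scheduler steps per update rather than per epoch), assuming incomplete final batches are dropped. It also shows one common way to pick a GPU per process from Open MPI's environment. The function names are hypothetical and not taken from `timm_main.py`.

```python
import math
import os

IMAGENET_TRAIN_IMAGES = 1_281_167  # size of the ImageNet-1k training split

def global_batch_size(num_nodes, gpus_per_node, per_gpu_batch):
    # One MPI process per GPU, each with its own slice of the mini-batch.
    return num_nodes * gpus_per_node * per_gpu_batch

def updates_per_epoch(num_images, batch, drop_last=True):
    # Optimizer updates per epoch; relevant when the scheduler steps per update.
    return num_images // batch if drop_last else math.ceil(num_images / batch)

def local_gpu_index(gpus_per_node, env=os.environ):
    # Open MPI exports OMPI_COMM_WORLD_LOCAL_RANK to each launched process.
    return int(env.get("OMPI_COMM_WORLD_LOCAL_RANK", "0")) % gpus_per_node

batch = global_batch_size(num_nodes=32, gpus_per_node=4, per_gpu_batch=32)
print(batch)                                                  # 4096
print(updates_per_epoch(IMAGENET_TRAIN_IMAGES, batch))        # 312
print(300 * updates_per_epoch(IMAGENET_TRAIN_IMAGES, batch))  # 93600
```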